import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from tqdm import tqdm
import os


class TransitionDataset(torch.utils.data.Dataset):
    """Dataset of concept-pair transitions selected by a mask matrix.

    Each sample is one ``(i, j)`` position at which ``M_p`` equals 1,
    together with the values stored at that position in ``M_g`` and ``M_c``.
    """

    def __init__(self, M_g, M_c, M_p):
        """Convert the three matrices to float tensors and index the mask.

        Args:
            M_g: concept ordering transition matrix derived from the concept
                ordering of the given sentences in the CommonGen dataset.
            M_c: concept ordering transition matrix derived from the number
                of paths between concepts in ConceptNet.
            M_p: matrix marking which transitions are considered; only the
                positions whose value equals 1 become samples.
        """
        as_float = dict(dtype=torch.float)
        self.M_g = torch.tensor(M_g, **as_float)
        self.M_c = torch.tensor(M_c, **as_float)
        self.M_p = torch.tensor(M_p, **as_float)
        # One row per selected transition, holding its (row, column) indices.
        self.dataset = (self.M_p == 1).nonzero()

    def __len__(self):
        """Return the number of selected transitions."""
        return self.dataset.shape[0]

    def __getitem__(self, index):
        """Return ``(i, j, M_g[i, j], M_c[i, j])`` for the index-th selected pair."""
        row, col = self.dataset[index]
        g_value = self.M_g[row, col]
        c_value = self.M_c[row, col]
        return row, col, g_value, c_value



class TransitionModel(nn.Module):
    """Score ordered concept pairs ``(i, j)`` with a learned inner product.

    Concept ``i`` is looked up in the ``v`` embedding table and concept ``j``
    in the ``w`` table; both vectors go through the same small MLP and the
    score is the dot product of the two projections.
    """

    def __init__(self, vocab_size, embed_size, pretrained_emb,
                 hidden_size=300, alpha=0.1):
        """Build the two embedding tables and the shared projection MLP.

        Args:
            vocab_size: number of concepts (stored, not otherwise used here).
            embed_size: width of the embedding vectors; it is the input size
                of the MLP, so it is expected to match
                ``pretrained_emb.size(1)``.
            pretrained_emb: 2-D float tensor used to initialise both tables.
            hidden_size: width of the MLP hidden and output layers.
            alpha: weight of the ``Mc_p`` term in ``loss_func``; the ``Mg_p``
                term is weighted by ``1 - alpha``.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_size = embed_size
        self.pretrained_emb = pretrained_emb
        self.h_d = hidden_size

        # from_pretrained wraps the tensor it is given without copying it.
        # Handing the same tensor to both tables would make v and w (and the
        # caller's tensor) share one storage, so every update to one table
        # would silently change the other. Each table gets its own copy.
        self.v = nn.Embedding.from_pretrained(
            pretrained_emb.detach().clone(), freeze=False
        )
        self.w = nn.Embedding.from_pretrained(
            pretrained_emb.detach().clone(), freeze=False
        )

        # The same projection is applied to both sides of a pair.
        self.hidden = nn.Sequential(
            nn.Linear(self.embed_size, self.h_d),
            nn.Tanh(),
            nn.Linear(self.h_d, self.h_d)
        )
        self.alpha = alpha

    def forward(self, i, j):
        """Return the transition score for each ``(i, j)`` index pair.

        Args:
            i: integer tensor of source concept indices.
            j: integer tensor of target concept indices, same shape as ``i``.
        Returns:
            Tensor with the same shape as ``i`` holding one score per pair.
        """
        vi = self.hidden(self.v(i))
        wj = self.hidden(self.w(j))

        # Reduce over the feature axis (the last one) so that scalar indices
        # and batched indices are both handled.
        return torch.sum(vi * wj, dim=-1)

    def loss_func(self, output, Mg_p, Mc_p):
        """The loss function used in training.

        Args:
            output: scores produced by ``forward``.
            Mg_p: the transition probability from concept i to concept j in Mg.
            Mc_p: the transition probability from concept i to concept j in Mc.
        Returns:
            Scalar tensor: ``alpha * MSE(output, Mc_p)
            + (1 - alpha) * MSE(output, Mg_p)``.
        """
        l_c = F.mse_loss(output, Mc_p)
        l_g = F.mse_loss(output, Mg_p)
        return self.alpha * l_c + (1 - self.alpha) * l_g

    def get_matrix(self):
        """Return the learnt score matrix as a NumPy array.

        Entry ``[i][j]`` is the score ``forward`` assigns to the pair
        ``(i, j)``; the matrix has one row and one column per embedding row.
        """
        # Inference only: skip autograd bookkeeping for the full-vocabulary
        # projection and matrix product.
        with torch.no_grad():
            res_v = self.hidden(self.v.weight)
            res_w = self.hidden(self.w.weight)
            o = torch.mm(res_v, res_w.T)

        return o.detach().cpu().numpy()
